import numpy as np
import matplotlib.pyplot as plt


def _warmup_lr(base_lr, warmup_length, step):
    return base_lr * (step + 1) / warmup_length


def _lr_adjuster(step):
    warmup_length = 300
    base_lr = 1e-4
    steps = 3125

    if step < warmup_length:
        lr = _warmup_lr(base_lr, warmup_length, step)
    else:
        e = step - warmup_length
        es = steps - warmup_length
        lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
    return lr


def main(total_steps=3125):
    """Plot the learning-rate schedule produced by ``_lr_adjuster``.

    Args:
        total_steps: number of 0-indexed steps to evaluate and plot. The
            default matches the 3125-step schedule length used by
            ``_lr_adjuster``.
    """
    steps = list(range(total_steps))
    # Evaluate each step at the same value used as its x coordinate.
    # ``_warmup_lr`` already adds 1 internally, so steps are 0-indexed here.
    lrs = [_lr_adjuster(step) for step in steps]

    plt.plot(steps, lrs)
    plt.xlabel("step")
    plt.ylabel("learning rate")
    plt.show()


# Only plot when run as a script, not when this module is imported.
if __name__ == "__main__":
    main()
